import torch.nn.functional as F


def ce_loss(pred, label, **kwargs):
    """Return the cross-entropy loss of ``pred`` against ``label``.

    Thin wrapper around ``torch.nn.functional.cross_entropy``.

    Args:
        pred: Unnormalised class scores (logits). See the PyTorch
            ``cross_entropy`` documentation for the accepted shapes.
        label: Targets matching ``pred``, either class indices or class
            probabilities, as accepted by ``F.cross_entropy``.
        **kwargs: Optional keyword arguments forwarded unchanged to
            ``F.cross_entropy`` (e.g. ``weight``, ``ignore_index``,
            ``reduction``, ``label_smoothing``). Omitting them gives exactly
            the same result as calling ``F.cross_entropy(pred, label)``.

    Returns:
        The loss tensor produced by ``F.cross_entropy``. This is a scalar
        under the default ``reduction='mean'``.
    """
    # Forward extra options so callers can configure the loss without
    # needing a separate wrapper for every variant.
    return F.cross_entropy(pred, label, **kwargs)


# Registry of the available loss functions, keyed by their name string.
loss_dict = dict(ce_loss=ce_loss)
